﻿using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Threading;
using Pfz.Caching;
using Pfz.Extensions;
using Pfz.Serialization;
using Pfz.Threading;

namespace Pfz.Remoting
{
	/// <summary>
	/// Class responsible for creating many channels inside another stream.
	/// This is used by the remoting framework, so each thread has its own
	/// channel inside a single tcp/ip connection.
	/// </summary>
	/// <remarks>
	/// Every packet read from the underlying stream starts with an 8-byte header:
	/// the target channel id followed by the payload size, both decoded with
	/// <see cref="BitConverter.ToInt32(byte[], int)"/>. Channel 0 is the main
	/// channel, which carries the control messages that create, associate and
	/// remove the other channels.
	/// </remarks>
	public class StreamChanneller:
		ThreadSafeExceptionAwareDisposable,
		IChanneller,
		IGarbageCollectionAware
	{
		#region Private and internal fields
			// The multiplexed stream. Dispose sets it to null, so readers copy it to a local first.
			private Stream _stream;
			
			// Channels by local id (id 0 is the main channel). Accessed under DisposeLock; Dispose sets it to null.
			private Dictionary<int, StreamChannel> _channels = new Dictionary<int, StreamChannel>();
			// Events signaled when the remote side answers a CreateChannel request. Accessed under DisposeLock.
			private Dictionary<int, ManagedManualResetEvent> _awaitingChannels = new Dictionary<int, ManagedManualResetEvent>();
			
			// Serializer used to write control messages on the main channel; only used while holding DisposeLock.
			private readonly BinarySerializer _mainChannelSerializer = _CreateSerializer();
			private StreamChannel _mainChannel;
			
			// Maximum size of each chunk handed to a channel's incoming queue.
			internal int _channelBufferSize;
			// Buffers drained by the writer thread (under DisposeLock) and written to _stream.
			internal readonly Queue<byte[]> _buffersToSend = new Queue<byte[]>();
			
			// Last channel id handed out; incremented atomically for both local and remote requests.
			private int _nextChannelId;
		#endregion
		
		#region Constructors
			/// <summary>
			/// Creates the channeller for the specified stream and allows you to
			/// specify the buffer size. For tcp/ip stream, use the bigger value
			/// between receive and send buffer size.
			/// The channeller does nothing until <see cref="Start"/> is called.
			/// </summary>
			/// <param name="stream">The stream that will carry all the channels.</param>
			/// <param name="bufferSizePerChannel">
			/// Buffer size per channel; incoming messages are split into chunks of at
			/// most this many bytes. Must be at least 256.
			/// </param>
			/// <param name="localEndpoint">Value exposed by <see cref="LocalEndpoint"/>.</param>
			/// <param name="remoteEndpoint">Value exposed by <see cref="RemoteEndpoint"/>.</param>
			/// <exception cref="ArgumentNullException"><paramref name="stream"/> is null.</exception>
			/// <exception cref="ArgumentException"><paramref name="bufferSizePerChannel"/> is less than 256.</exception>
			public StreamChanneller(Stream stream, int bufferSizePerChannel, string localEndpoint, string remoteEndpoint)
			{
				if (stream == null)
					throw new ArgumentNullException("stream");
					
				if (bufferSizePerChannel < 256)
					throw new ArgumentException("bufferSizePerChannel can't be less than 256 bytes", "bufferSizePerChannel");

				_stream = stream;
				_channelBufferSize = bufferSizePerChannel;
				LocalEndpoint = localEndpoint;
				RemoteEndpoint = remoteEndpoint;
			}

			private bool _started;
			/// <summary>
			/// Starts this channeller: registers the main channel (id 0) and starts
			/// the reader, writer and main-channel background threads.
			/// </summary>
			/// <exception cref="RemotingException">The channeller was already started.</exception>
			public void Start()
			{
				CheckUndisposed();

				if (_started)
					throw new RemotingException("This channeller is already started.");

				_started = true;
					
				StreamChannel mainChannel = new StreamChannel(this);
				_mainChannel = mainChannel;
				_channels.Add(0, mainChannel);
					
				Thread threadReader = new Thread(_Reader);
				threadReader.IsBackground = true;
				threadReader.Name = "StreamChanneller reader.";
				threadReader.Start();
				
				Thread threadWriter = new Thread(_Writer);
				threadWriter.IsBackground = true;
				threadWriter.Name = "StreamChanneller writer.";
				threadWriter.Start();
				
				Thread threadMainChannel = new Thread(_MainChannel);
				threadMainChannel.IsBackground = true;
				threadMainChannel.Name = "StreamChanneller main channel.";
				threadMainChannel.Start();
				
				GCUtils.RegisterForCollectedNotification(this);
			}
		#endregion
		#region Dispose
			/// <summary>
			/// Disposes the channeller and the stream, wakes the writer thread,
			/// disposes every channel (with <c>DisposeException</c>) and every
			/// awaiting event, then raises <see cref="Disposed"/>.
			/// </summary>
			/// <param name="disposing">true if called from Dispose() and false if called from destructor.</param>
			// NOTE(review): this class has no _writerEvent field; the suppression below looks stale.
			[SuppressMessage("Microsoft.Usage", "CA2213:DisposableFieldsShouldBeDisposed", MessageId = "_writerEvent")]
			protected override void Dispose(bool disposing)
			{
				if (disposing)
				{
					GCUtils.UnregisterFromCollectedNotification(this);
						
					Disposer.Dispose(ref _stream);

					// The writer thread waits on DisposeLock; wake it so it re-checks WasDisposed.
					lock(DisposeLock)
						Monitor.Pulse(DisposeLock);
					
					// Detach the dictionaries before iterating them: from here on,
					// other members observe null instead of a collection being enumerated.
					var channels = _channels;
					if (channels != null)
					{
						_channels = null;
						
						foreach(StreamChannel channel in channels.Values)
							channel.Dispose(DisposeException);
					}
					
					var awaitingChannels = _awaitingChannels;
					if (awaitingChannels != null)
					{
						_awaitingChannels = null;
						foreach(var mre in awaitingChannels.Values)
							mre.Dispose();
					}
				}
					
				base.Dispose(disposing);
				
				if (disposing)
				{
					var disposedHandler = Disposed;
					if (disposedHandler != null)
						disposedHandler(this, EventArgs.Empty);
				}
			}
		#endregion
		#region _Collected
			/// <summary>
			/// Called after a garbage collection. Copies the dictionaries and trims the
			/// send queue so capacity left over from earlier peaks can be released.
			/// Unregisters itself once the channeller is disposed.
			/// </summary>
			void IGarbageCollectionAware.OnCollected()
			{
				try
				{
					lock(DisposeLock)
					{
						if (WasDisposed)
						{
							GCUtils.UnregisterFromCollectedNotification(this);
							return;
						}

						_awaitingChannels = new Dictionary<int, ManagedManualResetEvent>(_awaitingChannels);
						_channels = new Dictionary<int, StreamChannel>(_channels);
						_buffersToSend.TrimExcess();
					}
				}
				catch(OutOfMemoryException)
				{
					// ignore out of memory exception, as lists are kept intact if there
					// is no memory.
				}
			}
		#endregion
		
		#region Properties
			/// <summary>
			/// Gets the LocalEndpoint.
			/// </summary>
			public string LocalEndpoint { get; private set; }

			/// <summary>
			/// Gets the RemoteEndpoint.
			/// </summary>
			public string RemoteEndpoint { get; private set; }
		#endregion
		#region Methods
			#region CreateChannel
				/// <summary>
				/// Creates a channel, sending the serializableData parameter to the
				/// other side, so it can decide what to do with this channel before it
				/// gets used (this avoids an extra tcp/ip packet for small information).
				/// Blocks until the remote side answers with its channel id.
				/// </summary>
				/// <param name="serializableData">Data to send to the other side.</param>
				/// <returns>A new channel.</returns>
				/// <remarks>Any failure disposes the whole channeller before the exception is rethrown.</remarks>
				public StreamChannel CreateChannel(object serializableData = null)
				{
					try
					{
						int channelId = Interlocked.Increment(ref _nextChannelId);
						StreamChannel channel = new StreamChannel(this);
						channel._id = channelId;
						
						ClassChannelCreated channelCreated = new ClassChannelCreated();
						channelCreated.SenderChannelId = channelId;
						channelCreated.Data = serializableData;

						using(var manualResetEvent = new ManagedManualResetEvent())
						{
							bool lockTaken = false;
							try
							{
								Monitor.Enter(DisposeLock, ref lockTaken);

								CheckUndisposed();
								_channels.Add(channelId, channel);
						
								_awaitingChannels.Add(channelId, manualResetEvent);
								
								try
								{
									// Ask the remote side to create its end of the channel.
									_mainChannelSerializer.Serialize(_mainChannel, channelCreated);
									_mainChannel.Flush();
								
									// The lock must be released while waiting: the ChannelAssociated answer
									// is processed by the main-channel thread, which needs DisposeLock.
									// Exiting inside a finally block presumably keeps lockTaken in sync
									// with the real lock state even under asynchronous exceptions.
									try
									{
									}
									finally
									{
										Monitor.Exit(DisposeLock);
										lockTaken = false;
									}

									manualResetEvent.WaitOne();

									Monitor.Enter(DisposeLock, ref lockTaken);
								
									// Dispose also disposes the awaiting events; fail if that happened.
									CheckUndisposed();
								}
								finally
								{
									if (!lockTaken)
									{
										Monitor.Enter(DisposeLock);
										lockTaken = true;
									}

									if (!WasDisposed)
										_awaitingChannels.Remove(channelId);
								}
							}
							finally
							{
								if (lockTaken)
									Monitor.Exit(DisposeLock);
							}
						}
						
						return channel;
					}
					catch(Exception exception)
					{
						Dispose(exception);
							
						throw;
					}
				}
			#endregion
			#region _RemoveChannel
				/// <summary>
				/// Removes the channel with the given local id and tells the remote side
				/// (best effort) that its channel <paramref name="remoteId"/> was removed.
				/// Does nothing if the channeller was already disposed.
				/// </summary>
				/// <param name="id">Local id of the channel to remove.</param>
				/// <param name="remoteId">Id of the matching channel on the remote side.</param>
				internal void _RemoveChannel(int id, int remoteId)
				{
					ChannelRemoved channelRemoved = new ChannelRemoved();
					channelRemoved.ReceiverChannelId = remoteId;

					// _channels, the main channel and its serializer are shared with the
					// reader, main-channel and CreateChannel callers, which all use them
					// under DisposeLock.
					lock(DisposeLock)
					{
						// Dispose sets _channels to null before disposing the channels
						// (NOTE: presumably StreamChannel calls back here while being disposed).
						var channels = _channels;
						if (channels == null)
							return;

						channels.Remove(id);

						try
						{
							_mainChannelSerializer.Serialize(_mainChannel, channelRemoved);
							_mainChannel.Flush();
						}
						catch
						{
							// Best effort: if the notification can't be sent the connection is
							// already failing, and the remote side will drop the channel when
							// the channeller itself goes away.
						}
					}
				}
			#endregion
		
			#region _Reader
				/// <summary>
				/// Reader thread body. Reads each 8-byte header (channel id, payload size),
				/// then moves the payload, in chunks of at most _channelBufferSize bytes,
				/// into the target channel's _inMessages queue, pulsing the channel's
				/// DisposeLock after each chunk. Payloads for unknown channels are discarded.
				/// Any failure disposes the channeller.
				/// </summary>
				private void _Reader()
				{
					try
					{
						byte[] headerBuffer = new byte[8];
						while(true)
						{
							_Read(headerBuffer, 8);
							
							int channelId = BitConverter.ToInt32(headerBuffer, 0);
							int messageSize = BitConverter.ToInt32(headerBuffer, 4);
							
							StreamChannel channel = null;
							lock(DisposeLock)
							{
								if (WasDisposed)
									return;

								_channels.TryGetValue(channelId, out channel);
							}
								
							if (channel == null)
							{
								// Still consume the payload so the next header is read correctly.
								_Discard(messageSize);
								continue;
							}

							int bytesLeft = messageSize;
							while (bytesLeft > 0)
							{
								if (WasDisposed)
									break;
							
								int count = bytesLeft;
								if (bytesLeft > _channelBufferSize)
									count = _channelBufferSize;
								
								byte[] messageBuffer;
								try
								{
									messageBuffer = new byte[count];
								}
								catch(Exception exception)
								{
									// Allocation failed: dispose only this channel and retry the
									// allocation so the stream stays in sync.
									channel.Dispose(exception);
									
									continue;
								}
								
								_Read(messageBuffer, count);
								bytesLeft -= count;
								
								lock(channel.DisposeLock)
								{
									if (!channel.WasDisposed)
									{
										try
										{
											channel._inMessages.Enqueue(messageBuffer);
										}
										catch(Exception exception)
										{
											channel.Dispose(exception);
											
											continue;
										}

										Monitor.Pulse(channel.DisposeLock);
									}
								}
							}
						}
					}
					catch(Exception exception)
					{
						Dispose(exception);
					}
				}
			#endregion
			#region _Read
				/// <summary>
				/// Reads exactly <paramref name="count"/> bytes into the start of
				/// <paramref name="buffer"/>. If the stream is gone or ends early, the
				/// channeller is disposed and an <see cref="IOException"/> is thrown.
				/// </summary>
				private void _Read(byte[] buffer, int count)
				{
					// Dispose can null _stream from another thread at any time, so only the
					// local copy is used below (reading the field again could hit null).
					var stream = _stream;
					if (stream == null)
					{
						var exception = new IOException("Stream closed.");
						Dispose(exception);
						throw exception;
					}

					int totalRead = 0;
					while(totalRead < count)
					{
						int read = stream.Read(buffer, totalRead, count-totalRead);
						
						// Stream.Read returns 0 only at end of stream.
						if (read == 0)
						{
							var exception = new IOException("Stream closed.");
							Dispose(exception);
							throw exception;
						}
						
						totalRead += read;
					}
				}
			#endregion
			#region _Discard
				/// <summary>
				/// Reads and throws away <paramref name="bytesToDiscard"/> bytes, using a
				/// scratch buffer of at most _channelBufferSize bytes.
				/// </summary>
				private void _Discard(int bytesToDiscard)
				{
					int bufferSize = Math.Min(bytesToDiscard, _channelBufferSize);
					byte[] discardBuffer = new byte[bufferSize];

					int bytesLeft = bytesToDiscard;
					while(bytesLeft > 0)
					{
						if (bytesLeft < bufferSize)
						{
							_Read(discardBuffer, bytesLeft);
							break;
						}
						
						_Read(discardBuffer, bufferSize);
						bytesLeft -= bufferSize;
					}
				}
			#endregion

			#region _Writer
				/// <summary>
				/// Writer thread body. While holding DisposeLock, writes every buffer queued
				/// in _buffersToSend to the stream. When the queue is empty it flushes the
				/// stream and waits on DisposeLock, which releases the lock until it is
				/// pulsed. Any failure disposes the channeller.
				/// </summary>
				private void _Writer()
				{
					var stream = _stream;
					if (stream == null)
						return;

					try
					{
						lock(DisposeLock)
						{
							if (WasDisposed)
								return;

							while (true)
							{
								if (_buffersToSend.Count == 0)
								{
									stream.Flush();
									Monitor.Wait(DisposeLock);

									if (WasDisposed)
										return;

									continue;
								}

								byte[] buffer = _buffersToSend.Dequeue();
								stream.Write(buffer, 0, buffer.Length);
							}
						}
					}
					catch(Exception exception)
					{
						Dispose(exception);
					}
				}
			#endregion
			
			#region _MainChannel
				/// <summary>
				/// Main-channel thread body: deserializes control messages from channel 0
				/// and runs them. Any failure (including the stream closing) disposes
				/// the channeller.
				/// </summary>
				private void _MainChannel()
				{
					var mainChannel = _mainChannel;
					if (mainChannel == null)
						return;

					try
					{
						// A separate serializer instance, because _mainChannelSerializer is used for writing.
						var serializer = _CreateSerializer();
						while(true)
						{
							object obj = serializer.Deserialize(mainChannel);
							var action = (IChannelAction)obj;
							action.Run(this);
						}
					}
					catch(Exception exception)
					{
						Dispose(exception);
					}
				}
			#endregion
			
			#region _CreateSerializer
				/// <summary>
				/// Creates a serializer that has the item serializers registered for the
				/// three control messages exchanged on the main channel.
				/// </summary>
				private static BinarySerializer _CreateSerializer()
				{
					var serializer = new BinarySerializer();
					serializer.Register(ClassChannelCreatedSerializer.Instance);
					serializer.Register(ChannelAssociatedSerializer.Instance);
					serializer.Register(ChannelRemovedSerializer.Instance);
					return serializer;
				}
			#endregion
		#endregion
		#region Events
			/// <summary>
			/// Event called when Dispose() has just finished.
			/// </summary>
			public event EventHandler Disposed;
			
			/// <summary>
			/// Event that is invoked when the remote side creates a new channel.
			/// Handlers run on an <c>UnlimitedThreadPool</c> thread with this channeller
			/// as sender. The channel is disposed as soon as the handlers return, or
			/// immediately if there is no handler.
			/// </summary>
			public event EventHandler<ChannelCreatedEventArgs> ChannelCreated;
		#endregion
		
		#region Nested classes
			/// <summary>
			/// A control message received on the main channel.
			/// </summary>
			private interface IChannelAction
			{
				void Run(StreamChanneller channeller);
			}

			/// <summary>
			/// Control message: the remote side created channel SenderChannelId and
			/// wants a local counterpart. Data is the user data given to CreateChannel.
			/// </summary>
			private sealed class ClassChannelCreated:
				IChannelAction
			{
				internal int SenderChannelId;
				internal object Data;

				/// <summary>
				/// Creates and registers the local channel, answers with ChannelAssociated,
				/// then raises ChannelCreated on a pool thread.
				/// </summary>
				public void Run(StreamChanneller channeller)
				{
					int localChannelId = Interlocked.Increment(ref channeller._nextChannelId);
					
					StreamChannel channel = new StreamChannel(channeller);
					channel._id = localChannelId;
					channel._remoteId = SenderChannelId;

					ChannelAssociated associated = new ChannelAssociated();
					associated.SenderChannelId = localChannelId;
					associated.ReceiverChannelId = SenderChannelId;

					lock(channeller.DisposeLock)
					{
						if (channeller.WasDisposed)
							return;

						channeller._channels.Add(localChannelId, channel);
						
						channeller._mainChannelSerializer.Serialize(channeller._mainChannel, associated);
						channeller._mainChannel.Flush();
					}
					
					ChannelCreatedEventArgs args = new ChannelCreatedEventArgs();
					args.Channel = channel;
					args.Data = Data;
					
					UnlimitedThreadPool.Run
					(
						(args2) =>
						{
							using(args2.Channel)
							{
								// Copy the delegate so an unsubscription can't race the null check.
								// With no subscriber the channel is simply disposed by the using.
								var handler = channeller.ChannelCreated;
								if (handler != null)
									handler(channeller, args2);
							}
						},
						args
					);
				}
			}

			/// <summary>
			/// Writes/reads <see cref="ClassChannelCreated"/> as: Data (inner
			/// serialization) followed by SenderChannelId (compressed Int32).
			/// </summary>
			private sealed class ClassChannelCreatedSerializer:
				ItemSerializer<ClassChannelCreated>
			{
				internal static readonly ClassChannelCreatedSerializer Instance = new ClassChannelCreatedSerializer();

				public override void Serialize(ConfigurableSerializerBase serializer, ClassChannelCreated item)
				{
					serializer.InnerSerialize(item.Data);
					serializer.Stream.WriteCompressedInt32(item.SenderChannelId);
				}
				public override ClassChannelCreated Deserialize(ConfigurableSerializerBase deserializer)
				{
					var result = new ClassChannelCreated();
					result.Data = deserializer.InnerDeserialize();
					result.SenderChannelId = deserializer.Stream.ReadCompressedInt32();
					return result;
				}
			}
			
			/// <summary>
			/// Control message: the remote side removed its end of the local channel
			/// ReceiverChannelId.
			/// </summary>
			private sealed class ChannelRemoved:
				IChannelAction
			{
				internal int ReceiverChannelId;

				/// <summary>
				/// Begins disposing the local channel; unknown ids are ignored.
				/// </summary>
				public void Run(StreamChanneller channeller)
				{
					StreamChannel channel;
					lock(channeller.DisposeLock)
					{
						if (channeller.WasDisposed)
							return;

						channeller._channels.TryGetValue(ReceiverChannelId, out channel);
					}
					
					// Done outside the channeller lock.
					if (channel != null)
						channel._BeginDispose();
				}
			}

			/// <summary>
			/// Writes/reads <see cref="ChannelRemoved"/> as a single compressed Int32.
			/// </summary>
			private sealed class ChannelRemovedSerializer:
				ItemSerializer<ChannelRemoved>
			{
				internal static readonly ChannelRemovedSerializer Instance = new ChannelRemovedSerializer();

				public override void Serialize(ConfigurableSerializerBase serializer, ChannelRemoved item)
				{
					serializer.Stream.WriteCompressedInt32(item.ReceiverChannelId);
				}
				public override ChannelRemoved Deserialize(ConfigurableSerializerBase deserializer)
				{
					var result = new ChannelRemoved();
					result.ReceiverChannelId = deserializer.Stream.ReadCompressedInt32();
					return result;
				}
			}
			
			/// <summary>
			/// Control message: answer to ClassChannelCreated. The remote side created
			/// channel SenderChannelId as the counterpart of local channel ReceiverChannelId.
			/// </summary>
			private sealed class ChannelAssociated:
				IChannelAction
			{
				internal int ReceiverChannelId;
				internal int SenderChannelId;

				/// <summary>
				/// Records the remote id and releases the thread waiting in CreateChannel.
				/// An unknown id makes the indexers throw, which ends the main-channel
				/// thread and disposes the channeller.
				/// </summary>
				public void Run(StreamChanneller channeller)
				{
					StreamChannel channel = null;
					lock(channeller.DisposeLock)
					{
						if (channeller.WasDisposed)
							return;

						channel = channeller._channels[ReceiverChannelId];
						channel._remoteId = SenderChannelId;
						channeller._awaitingChannels[channel._id].Set();
					}
				}
			}

			/// <summary>
			/// Writes/reads <see cref="ChannelAssociated"/> as ReceiverChannelId then
			/// SenderChannelId, both compressed Int32.
			/// </summary>
			private sealed class ChannelAssociatedSerializer:
				ItemSerializer<ChannelAssociated>
			{
				internal static readonly ChannelAssociatedSerializer Instance = new ChannelAssociatedSerializer();

				public override void Serialize(ConfigurableSerializerBase serializer, ChannelAssociated item)
				{
					var stream = serializer.Stream;
					stream.WriteCompressedInt32(item.ReceiverChannelId);
					stream.WriteCompressedInt32(item.SenderChannelId);
				}
				public override ChannelAssociated Deserialize(ConfigurableSerializerBase deserializer)
				{
					var stream = deserializer.Stream;

					var result = new ChannelAssociated();
					result.ReceiverChannelId = stream.ReadCompressedInt32();
					result.SenderChannelId = stream.ReadCompressedInt32();
					return result;
				}
			}
		#endregion

		#region IChanneller Members
			IChannel IChanneller.CreateChannel(object createData)
			{
				return CreateChannel(createData);
			}
		#endregion
	}
}
